from curses import raw
import os
import argparse
import torch
import csv
from datetime import datetime
from huggingface_hub import login
from tqdm import tqdm
from settings import *
from transformers import AutoTokenizer, AutoModelForCausalLM

def extract_number(input_string):
    """Return the first digit character found in ``input_string``.

    A character qualifies when ``str.isdigit()`` is true for it. Scanning stops
    at the first match.

    Returns:
        The matching one-character string, or ``None`` when ``input_string``
        contains no digit.
    """
    digits = (ch for ch in input_string if ch.isdigit())
    return next(digits, None)

def parse_result(result):
    """Convert a raw model response into an integer rating.

    Only the first digit character of ``result`` is used, so ``"5/6"`` yields 5.

    Raises:
        TypeError: if ``result`` contains no digit (``extract_number`` returns
            ``None``, which ``int`` rejects).
        ValueError: if the first digit character cannot be parsed by ``int``.
    """
    first_digit = extract_number(result.strip())
    return int(first_digit)


def get_completion(model, tokenizer, messages, past_key_values, max_new_tokens, temperature):
    """Generate one assistant reply for ``messages`` and return a reusable KV cache.

    Args:
        model: Causal LM whose ``generate`` method is called.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
        past_key_values: Cache from a previous call covering a prefix of the
            templated ``messages``, or ``None`` to encode the prompt from scratch.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: ``0`` (or ``None`` / any non-positive value) selects greedy
            decoding; a positive value enables sampling at that temperature.

    Returns:
        ``(response, cache)``. ``response`` is the decoded completion with
        special tokens skipped. ``cache`` is a tuple with one ``(key, value)``
        pair per layer, sliced along dim 2 to the prompt length only.
    """
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = input_ids.shape[-1]

    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "return_dict_in_generate": True,
        "past_key_values": past_key_values,
        # Explicitly disable nucleus filtering rather than inheriting top_p
        # from the model's generation_config.
        "top_p": None,
    }
    if temperature is not None and temperature > 0:
        # Previously both branches were greedy, so a non-zero --temp was ignored.
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generate_kwargs.update(do_sample=False, temperature=None)

    outputs = model.generate(input_ids, **generate_kwargs)
    response = tokenizer.decode(outputs.sequences[0][prompt_len:], skip_special_tokens=True)

    # Keep only the prompt portion of the cache. The generated tokens may differ
    # from the assistant turn the caller re-inserts into ``messages``, so caching
    # them would feed the model a different context than the templated history.
    # Slicing also produces new objects, so a cache the caller holds onto (e.g.
    # for a retry) is not affected by later generate calls.
    cache = outputs["past_key_values"]
    trimmed = []
    for layer in range(len(cache)):
        layer_kv = cache[layer]
        trimmed.append((
            layer_kv[0][:, :, :prompt_len, :],
            layer_kv[1][:, :, :prompt_len, :],
        ))
    return response, tuple(trimmed)

def get_settings(args):
    """Instantiate the experiment setting selected by ``args.setting``.

    Args:
        args: Parsed CLI namespace; ``args.setting`` picks the class and the
            whole namespace is passed to its constructor.

    Returns:
        A new ``BasicSetting``, ``BwvrSetting``, ``NamesSetting``,
        ``DemographicSetting`` or ``PersonasSetting`` instance.

    Raises:
        ValueError: if ``args.setting`` is not a known setting name (previously
            ``None`` was returned, which failed later with an obscure error).
    """
    # Lambdas defer the class-name lookup, so only the selected class has to exist.
    factories = {
        "basic": lambda: BasicSetting(args),
        "bwvr": lambda: BwvrSetting(args),
        "names": lambda: NamesSetting(args),
        "demographic": lambda: DemographicSetting(args),
        "personas": lambda: PersonasSetting(args),
    }
    try:
        factory = factories[args.setting]
    except KeyError:
        raise ValueError(
            f"Unknown setting {args.setting!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()
        
def _resolve_model_id(args):
    """Map ``args.model_type`` / ``args.model_size`` to a Hugging Face model id.

    Raises:
        ValueError: if ``args.model_type`` is neither ``"llama"`` nor ``"gemma"``.
    """
    if args.model_type == 'llama':
        return f"meta-llama/Meta-Llama-3.1-{args.model_size}B-Instruct"
    if args.model_type == 'gemma':
        return f'google/gemma-2-{args.model_size}b-it'
    raise ValueError(f"Unsupported model_type: {args.model_type!r}")


def _rate_descriptions(model, tokenizer, setting, args, identifier, descriptions):
    """Ask the model to rate each description within one multi-turn conversation.

    The first user turn is wrapped by ``setting.generate_prompt``; later turns
    send only the numbered description. Each description gets up to 10 attempts
    to produce a rating in 1..6.

    Returns:
        The collected ratings. The list is shorter than ``descriptions`` when a
        description exhausted its attempts; the remaining descriptions are skipped.
    """
    messages = []
    rating_list = []
    past_key_values = None

    for j, description in enumerate(descriptions):
        numbered = f"{j+1}. {description}"
        prompt = setting.generate_prompt(identifier, numbered) if j == 0 else numbered
        messages.append({"role": "user", "content": prompt})

        success = False
        for _attempt in range(10):
            prev_past_key_values = past_key_values

            try:
                response, past_key_values = get_completion(model, tokenizer, messages, past_key_values, max_new_tokens=24, temperature=args.temp)
            except Exception:
                # Reusing the cache can fail; fall back to encoding the full prompt.
                response, past_key_values = get_completion(model, tokenizer, messages, None, max_new_tokens=24, temperature=args.temp)

            # NOTE(review): presumably cleared so later generate calls may pass an
            # explicit past_key_values without conflicting with a configured
            # cache implementation — confirm.
            model.generation_config.cache_implementation = None

            print(f"{j+1}. Answer: {response}")

            try:
                rating = parse_result(response)
            except (TypeError, ValueError):
                rating = None

            if rating is not None and 1 <= rating <= 6:
                messages.append({"role": "assistant", "content": f" {rating}\n"})
                rating_list.append(rating)
                success = True
                break

            print("Error with model output")
            past_key_values = prev_past_key_values

        if not success:
            print("Error with output")
            break

    return rating_list


def run(args):
    """Run the rating experiment described by ``args`` and write results to CSV.

    The CSV is written to ``{output_dir}/serial_results/{model_type}_{model_size}b_results/``.
    There are 150 iterations for the 'basic'/'bwvr' settings and 300 otherwise.
    An iteration's rows are written only if every description received a valid rating.

    Raises:
        ValueError: for an unsupported ``args.model_type``.
        FileExistsError: if the results CSV already exists. This is checked
            before the model is loaded.
    """
    setting = get_settings(args)
    model_id = _resolve_model_id(args)

    results_path = f'{args.output_dir}/serial_results/{args.model_type}_{args.model_size}b_results'
    os.makedirs(results_path, exist_ok=True)

    temp_str = str(args.temp).replace('.', '')

    if args.setting in ['basic', 'bwvr']:
        csv_filename = f"{results_path}/{args.setting}_{args.model_type}{args.model_size}_{args.gender}_{temp_str}.csv"
        required_iterations = 150
    else:
        csv_filename = f"{results_path}/{args.setting}_{args.model_type}{args.model_size}_{temp_str}.csv"
        required_iterations = 300

    # Fail fast, before the expensive model download/load.
    if os.path.exists(csv_filename):
        raise FileExistsError(f"Results file already exists: {csv_filename}")

    # Never hard-code access tokens: the literal that used to be here is exposed
    # in source control and should be revoked. Supply the token via HF_TOKEN;
    # otherwise any existing huggingface_hub login is used.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, max_length=4096, padding_side='left')

    if args.model_type == "llama":
        tokenizer.pad_token_id = tokenizer.eos_token_id

    with open(csv_filename, "w", newline="", encoding="utf-8") as csvfile:
        csvwriter = csv.writer(csvfile)
        setting.write_first_csv_row(csvwriter)
        date_str = datetime.now().strftime("%x")

        for i in tqdm(range(required_iterations)):
            identifier = setting.get_identifer(i)
            descriptions = setting.choose_decriptions(args, identifier)

            print(f"Iteration {i + 1}")
            if identifier is not None:
                print(f"Identifier: {identifier}")

            rating_list = _rate_descriptions(model, tokenizer, setting, args, identifier, descriptions)

            if len(rating_list) == len(descriptions):
                for k, (description, rating) in enumerate(zip(descriptions, rating_list)):
                    setting.write_csv_row(args, csvwriter, date_str, i, k, description, rating, identifier)
                # Persist completed iterations so a crash mid-run loses little work.
                csvfile.flush()
            else:
                print(f"Failed to parse result for iteration {i}")

    print(f"Results saved to: {csv_filename}")

if __name__ == "__main__":
    # Command-line entry point. `choices` rejects unsupported model types and
    # settings up front instead of failing deep inside run().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_size", default=8, type=int)
    parser.add_argument("--model_type", default='llama', type=str, choices=['llama', 'gemma'])
    parser.add_argument("--temp", default=0.0, type=float,
                        help="Decoding temperature; 0 selects greedy decoding.")
    parser.add_argument("--gender", default='female', type=str,
                        help="Included in the CSV filename for the 'basic' and 'bwvr' settings.")
    parser.add_argument("--setting", default='demographic', type=str,
                        choices=['basic', 'bwvr', 'names', 'demographic', 'personas'])
    parser.add_argument("--output_dir", default='.', type=str)

    args = parser.parse_args()
    run(args)